from .utils import registry
from lightning import Trainer
from lightning.pytorch.loggers import Logger
from typing import List, Union, Optional
from pydantic import constr, conint


@registry.trainers('lightning')
def build_lightning_trainer(*callbacks,
                            devices: Union[int, List[int]],
                            logger: Logger,
                            # Patterns are anchored with ^...$ so only the exact listed values
                            # validate; the previous unanchored form let any string that merely
                            # *started* with an option (e.g. "ddp_bogus", "gpu2") slip through,
                            # because pydantic v1 checks constr regexes with re.match.
                            # NOTE(review): pydantic v2 renamed constr's ``regex`` kwarg to
                            # ``pattern`` — confirm the pinned pydantic major version.
                            strategy: constr(regex=r"^(ddp|fsdp|auto|deepspeed)$"),
                            accelerator: constr(regex=r"^(cpu|gpu|tpu|ipu|hpu|mps|auto|cuda)$"),
                            precision: constr(regex=r"^(16-mixed|bf16-mixed|32-true|64-true|16|32|64|bf16)$"),
                            fast_dev_run: bool = False,
                            max_epochs: int = 10,
                            min_epochs: Optional[int] = None,
                            max_steps: int = -1,
                            min_steps: Optional[int] = None,
                            limit_train_batches: Optional[float] = None,
                            limit_val_batches: Optional[float] = None,
                            accumulate_grad_batches: conint(ge=1, le=10) = 1):
    """Build a ``lightning.Trainer``, registered under the ``'lightning'`` key.

    All arguments are forwarded to the ``Trainer`` constructor unchanged.

    Args:
        *callbacks: Zero or more Lightning callback objects, passed positionally
            and forwarded to ``Trainer(callbacks=...)`` as a list.
        devices: Device count, or explicit list of device indices.
        logger: Lightning logger instance for experiment tracking.
        strategy: Distributed strategy name; one of ``ddp``, ``fsdp``,
            ``auto``, ``deepspeed``.
        accelerator: Hardware accelerator name; one of ``cpu``, ``gpu``,
            ``tpu``, ``ipu``, ``hpu``, ``mps``, ``auto``, ``cuda``.
        precision: Numeric precision setting (e.g. ``16-mixed``, ``32-true``).
        fast_dev_run: If True, run a single batch through train/val for a
            quick sanity check.
        max_epochs: Upper bound on training epochs.
        min_epochs: Optional lower bound on training epochs.
        max_steps: Upper bound on optimizer steps; ``-1`` means no limit.
        min_steps: Optional lower bound on optimizer steps.
        limit_train_batches: Optional fraction (or count) of training batches
            to use per epoch; ``None`` leaves the Trainer default.
        limit_val_batches: Optional fraction (or count) of validation batches
            to use per epoch; ``None`` leaves the Trainer default.
        accumulate_grad_batches: Number of batches over which to accumulate
            gradients before an optimizer step (1-10).

    Returns:
        A configured ``lightning.Trainer`` instance.
    """
    # Trainer expects a list of callbacks, not a tuple from *args.
    callbacks = list(callbacks)

    return Trainer(callbacks=callbacks,
                   devices=devices,
                   logger=logger,
                   precision=precision,
                   strategy=strategy,
                   accelerator=accelerator,
                   max_epochs=max_epochs,
                   fast_dev_run=fast_dev_run,
                   limit_train_batches=limit_train_batches,
                   limit_val_batches=limit_val_batches,
                   min_epochs=min_epochs,
                   max_steps=max_steps,
                   min_steps=min_steps,
                   accumulate_grad_batches=accumulate_grad_batches)